import random  # to set the python random seed
import numpy  # to set the numpy random seed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import warnings
warnings.filterwarnings('ignore')

import wandb
# Start the W&B run for this script.
wandb.init(project="thorough-pytorch", name="wandb_demo")

# Hyperparameters. They are stored on wandb.config, which the rest of this
# file reads through the module-level name `config`.
config = wandb.config
_HYPERPARAMETERS = {
    "batch_size": 64,       # batch size main() passes to both DataLoaders
    "test_batch_size": 10,  # not referenced in the visible code
    "epochs": 5,            # number of train/test rounds run by main()
    "lr": 0.01,             # SGD learning rate
    "momentum": 0.1,        # SGD momentum
    "use_cuda": False,      # opt-in; main() also checks torch.cuda.is_available()
    "seed": 2043,           # value main() hands to set_seed()
    "log_interval": 10,     # not referenced in the visible code
}
# Set one attribute at a time, in the order listed above.
for _name, _value in _HYPERPARAMETERS.items():
    setattr(config, _name, _value)

# Seed every random number generator used by this script.
def set_seed(seed):
    """Seed the Python, PyTorch and NumPy random number generators.

    Args:
        seed: Integer seed applied to ``random``, ``torch`` and
            ``numpy.random``.
    """
    # Use the argument rather than the global config so the caller decides
    # which seed is applied.
    random.seed(seed)
    torch.manual_seed(seed)
    numpy.random.seed(seed)


def train(model, device, train_loader, optimizer):
    """Run one pass of gradient updates over ``train_loader``.

    Args:
        model: Module to optimise; it is switched to training mode first.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer that is stepped once per batch.

    Returns:
        None. The model parameters are updated in place.
    """
    model.train()
    # The loss module holds no per-batch state, so one instance serves all batches.
    loss_fn = nn.CrossEntropyLoss()

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous batch before backprop.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

# wandb.log records metrics (accuracy, loss and example predictions) so the
# network's performance can be inspected at any time.
def test(model, device, test_loader, classes):
    """Evaluate ``model`` on ``test_loader`` and log the results with wandb.

    Args:
        model: Module to evaluate; it is switched to eval mode first.
        device: Device each batch is moved to before the forward pass.
        test_loader: Loader yielding ``(data, target)`` batches. It must expose
            ``.dataset`` with a length, which is used to normalise the metrics.
        classes: Sequence of class names indexed by label id.

    Returns:
        None. One ``wandb.log`` call is made with the keys ``"Examples"``,
        ``"Test Accuracy"`` (percent) and ``"Test Loss"`` (mean per-sample
        cross-entropy).
    """
    model.eval()
    # Sum per-sample losses so the total can be divided by the dataset size.
    # Adding up per-batch means would make the metric grow with the number of
    # batches and over-weight a short final batch.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    num_samples = len(test_loader.dataset)
    total_loss = 0.0
    correct = 0
    example_images = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

            # Log the first sample of every batch as a captioned example.
            pred_idx = pred[0].item()
            truth_idx = target[0].item()
            # The model may emit more logits than there are class names; fall
            # back to the raw index instead of raising IndexError.
            if pred_idx < len(classes):
                pred_name = classes[pred_idx]
            else:
                pred_name = str(pred_idx)
            example_images.append(wandb.Image(
                data[0],
                caption="Pred:{} Truth:{}".format(pred_name, classes[truth_idx])))

    # Use wandb.log to record the metrics of interest.
    wandb.log({
        "Examples": example_images,
        "Test Accuracy": 100. * correct / num_samples,
        "Test Loss": total_loss / num_samples,
    })

# NOTE(review): sets a `watch_called` attribute on the wandb module before
# main() calls wandb.watch(). Presumably this resets a "watch already called"
# guard so the script can be re-run in a live session — confirm it is still
# needed with the wandb version in use.
wandb.watch_called = False 


def main():
    """Fine-tune a pretrained ResNet-18 on CIFAR-10 and track the run with wandb.

    Reads its settings from the module-level ``config``, downloads CIFAR-10
    into ``dataset/``, alternates ``train`` and ``test`` for ``config.epochs``
    rounds, then writes the weights to ``model.pth`` and passes that file to
    ``wandb.save``.
    """
    use_cuda = config.use_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # Seed the random number generators.
    set_seed(config.seed)
    torch.backends.cudnn.deterministic = True

    # Preprocessing: convert to tensors, then normalise each of the three
    # channels with mean 0.5 and std 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load the data.
    train_loader = DataLoader(datasets.CIFAR10(
        root='dataset',
        train=True,
        download=True,
        transform=transform
    ), batch_size=config.batch_size, shuffle=True, **kwargs)

    # NOTE(review): config.test_batch_size is defined but evaluation uses
    # config.batch_size. test() logs one example image per batch, so switching
    # would change how many images are logged — confirm before changing.
    test_loader = DataLoader(datasets.CIFAR10(
        root='dataset',
        train=False,
        download=True,
        transform=transform
    ), batch_size=config.batch_size, shuffle=False, **kwargs)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # The pretrained network ends in a 1000-way ImageNet classifier. Replace it
    # with a head sized for CIFAR-10 so every predicted index maps to an entry
    # in `classes`.
    model = resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, len(classes))
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)

    wandb.watch(model, log="all")
    for _ in range(config.epochs):
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader, classes)

    # Save the model locally and pass the file to wandb.
    torch.save(model.state_dict(), 'model.pth')
    wandb.save('model.pth')


if __name__ == '__main__':
    main()